__Author__ = "Peter Herman"
__Project__ = "internet_and_trade"
__Created__ = "Sept 20, 2022"
__Description__ = '''This script performs a simulation of internet connectivity using revised estimates (terms split 
into income levels) and the Eora trade data. It also includes a new collection of post-estimations analsysis (GDP/Trade
changes by country-groups)'''


import pandas as pd
import gegravity as ge
import gme as gme
from warnings import warn
import re as regex
from math import ceil


# ----
# Specification
# ----
# Set project directory.
# Backslashes are escaped explicitly: the unescaped form relied on invalid escape sequences (\w, \P, \p, \i), which
# newer Python versions warn about. A raw string is not used because the path must end in a backslash.
root_directory = "D:\\work\\Peter_Herman\\projects\\internet_and_trade_public_files\\"

# Define location of data inputs (all files sit directly in the project directory)
estimation_data_local = "{}all_xvalues_aggregate.csv".format(root_directory)
fe_est_local = "{}all_countrypairFE_aggregate.csv".format(root_directory)
betas_local = "{}beta_ests_aggregate.csv".format(root_directory)
internet_local = "{}WB_internet.csv".format(root_directory)
eora_trade_local = "{}eora_trade_2016.csv".format(root_directory)
country_names_local = "{}country_year_coverage_v2_0.csv".format(root_directory)

# Name of model to use in model outputs
version = 'nigeria'
# Base year to use
base_year = 2016

# Column name to use as trade values from Eora dataset
trade_var = 'eora_trade_total_$M'

# Location to save model results
save_path = "{}".format(root_directory)


# High and high-middle income level countries (ISO 3-letter codes).
# Any country not listed here is treated as "low income" in the group calculations further down.
# NOTE(review): "BLG" is not an ISO 3166-1 alpha-3 code; Bulgaria is "BGR" (which appears in the list of modeled
# countries), so Bulgaria currently falls into the low-income group. Left unchanged here so the existing grouping is
# preserved -- confirm whether this should be corrected.
# NOTE(review): "PLW" is listed twice; harmless for membership tests.
high_income_cntry = ["AND", "ATG", "ABW", "AUS", "AUT", "BHS", "BHR", "BRB", "BEL", "BMU", "VGB", "BRN",
                     "CAN", "CYM", "CHL", "HRV", "CUW", "CYP", "CZE", "DNK", "EST", "FRO", "FIN", "FRA", "PYF", "DEU",
                     "GIB", "GRC", "GRL", "GUM", "HKG", "HUN", "ISL", "IRL", "IMN", "ISR", "ITA", "JPN", "KOR", "KWT",
                     "LVA", "LIE", "LTU", "LUX", "MAC", "MLT", "MCO", "NRU", "NLD", "NCL", "NZL", "MNP", "NOR", "OMN",
                     "PLW", "POL", "PRT", "PRI", "QAT", "SMR", "SAU", "SYC", "SGP", "SXM", "SVK", "SVN", "ESP", "KNA",
                     "MAF", "SWE", "CHE", "TWN", "TTO", "TCA", "ARE", "GBR", "USA", "URY", "VIR", "ALB", "FJI", "NAM",
                     "ASM", "GAB", "MKD", "ARG", "GEO", "PLW", "ARM", "GRD", "PRY", "AZE", "GTM", "PER", "BLR", "GUY",
                     "RUS", "BLZ", "IRQ", "SRB", "BIH", "JAM", "ZAF", "BWA", "JOR", "LCA", "BRA", "KAZ", "VCT", "BLG",
                     "KSV", "SUR", "CHN", "LBY", "THA", "COL", "MYS", "TON", "CRI", "MDV", "TUR", "CUB", "MHL", "TKM",
                     "DMA", "MUS", "TUV", "DOM", "MEX", "GNQ", "MDA", "ECU", "MNE", "ROU", "PAN"]




# ----
# Prep Input Data
# ----
# Builds `sub_panel`: a square exporter x importer panel for the modeled countries in the base year, holding the
# estimation covariates, the Eora trade values, and each country's total output and expenditure.

raw_input_data = pd.read_csv(estimation_data_local)

# Determine number of zeros (the describe() output is only visible when run interactively)
zeros = raw_input_data[['trade']].copy()
zeros['zeros'] = 1
zeros.loc[zeros['trade']>0,'zeros']=0
zeros.describe()

# Check for duplicate values (kept for manual inspection; nothing below reads it)
dups = raw_input_data.loc[raw_input_data.duplicated(['importer','exporter','year']),:]

# Select base year data
input_data = raw_input_data.loc[raw_input_data['year']==base_year,:].copy()
# Rename the estimation trade column so it is distinguishable from the Eora trade values merged in below
input_data.rename(columns = {'trade':'itpd_trade_total'}, inplace = True)

data_summary = input_data.describe()

# Add eora trade info (data is in $ Millions)
eora_trade = pd.read_csv(eora_trade_local)
eora_codes = eora_trade['exporter'].unique().tolist()
itpd_codes = input_data['exporter'].unique().tolist()

# Combine Eora with the covariates from the econometric estimations
# (validate='1:1' makes pandas raise if either side has a repeated exporter-importer pair)
input_data = input_data.merge(eora_trade, how = 'left', on = ['exporter','importer'],  validate = '1:1')

# Top trading countries based on ITPD-E flows
top_countries = ['USA', 'JPN', 'DEU', 'FRA', 'CHN', 'IND', 'ITA', 'ESP', 'NLD', 'KOR', 'MEX', 'CHE', 'BEL', 'HKG',
                 'GBR', 'POL', 'AUT', 'NGA', 'CAN', 'DNK', 'CZE', 'ISR', 'CHL', 'RUS', 'FIN', 'SGP', 'IDN', 'ROU',
                 'BRA', 'PRT', 'IRL', 'MYS', 'AUS', 'TUR', 'THA', 'VNM', 'PER', 'GRC', 'LUX', 'HUN', 'SWE', 'PAK',
                 'SVK', 'SAU', 'UKR', 'PHL', 'ZAF', 'NOR', 'BGR', 'IRN', 'ECU', 'KAZ', 'ARG', 'SVN', 'LKA', 'CRI',
                 'EGY', 'LTU', 'BGD'] +["KHM"]

# Create square panel (every exporter paired with every importer, including domestic pairs)
top_country_pairs = [(i,j) for i in top_countries for j in top_countries]
top_country_pairs = pd.DataFrame(top_country_pairs, columns = ['exporter', 'importer'])
sub_panel = top_country_pairs.merge(input_data, how = 'left', on = ['exporter', 'importer'], validate = '1:1')


# Determine share of trade represented
total_trade = input_data[trade_var].sum()
sub_trade = sub_panel[trade_var].sum()
trade_covered = sub_trade/total_trade

# Check that the data dimensions are correct: any NaN in any column means a pair lacked input data
sub_panel_missing = sub_panel.loc[sub_panel.isna().any(axis = 1),:]
if sub_panel_missing.shape[0]>0:
    raise ValueError("Missing some input data")

# Create output and expenditure.
# The aggregation is named by string ('sum'): passing the builtin `sum` to agg is deprecated in pandas.
expenditures = sub_panel.groupby('importer').agg({trade_var:'sum'}).reset_index()
expenditures.columns = ['importer', 'expenditure']
outputs = sub_panel.groupby('exporter').agg({trade_var:'sum'}).reset_index()
outputs.columns = ['exporter', 'output']

# Add outputs and expenditures to input panel
sub_panel = sub_panel.merge(expenditures, how = 'outer', on = 'importer', validate='m:1')
sub_panel = sub_panel.merge(outputs, how = 'outer', on = 'exporter', validate='m:1')

# Check for any sub-sample observations that weren't used in estimation
# (columns starting with '_est' are expected to equal 1 on every row; a minimum below 1 triggers a warning)
sum_info = sub_panel.describe()
est_indicators = [col for col in sum_info.columns if col.startswith('_est')]
for col in est_indicators:
    if sum_info.loc['min',col] <1:
        warn("{} has some omitted rows".format(col))

# ----
# Prep Coefficient Estimates
# ----
# Builds `params` (one row per cost variable with its coefficient), the gegravity cost-coefficient object, and
# `cost_vars`, the list of cost variables used by the GE model.

# Set some column labels
regression_names = ['connect_provision']
reg_spec = 'connect_provision'

# Load pair fixed effect estimates, keeping only pairs where both countries are in the modeled sample.
# The filtered frame is copied so the column assignments below act on an independent frame, not a slice.
fe_ests = pd.read_csv(fe_est_local)
in_sample = fe_ests['importer'].isin(top_countries) & fe_ests['exporter'].isin(top_countries)
fe_ests = fe_ests.loc[in_sample,:].copy()
# Keep the first two column names as they are and label the estimate column(s) with the regression names
fe_ests.columns = [fe_ests.columns[0], fe_ests.columns[1]] + regression_names
# Pair identifier in the form "<exporter>_<importer>"
fe_ests['var'] = fe_ests['exporter']+'_'+fe_ests['importer']

# Load other coefficient estimates
beta_ests = pd.read_csv(betas_local, header=None)
beta_ests.columns = ['var']+ regression_names
# Drop rows containing standard errors instead of coefficient estimates (those rows have no variable label)
beta_ests = beta_ests.loc[~beta_ests['var'].isna(),:].copy()
for reg in regression_names:
    beta_ests[reg] = beta_ests[reg].astype(float)


# Combine cost vars into a single dataframe
select_cols = ['var', reg_spec]
params = pd.concat([beta_ests[select_cols], fe_ests[select_cols]],axis =0)
# Drop vars without parameter ests
params = params.loc[~params[reg_spec].isna(),:]

# Create gegravity cost object
cost_object = ge.CostCoeffs(params, identifier_col='var', coeff_col=reg_spec)

# Get a list of cost variables to use
cost_vars = params['var'].tolist()
# Remove all border-year fixed effects then add back in the baseline year's term
# (raw string so that \d reaches the regex engine without relying on an invalid string escape)
cost_vars = [var for var in cost_vars if not regex.match(r'foreign_\d\d\d\d',var)]
# Drop Constant as pair fixed effects were not estimated with constant
cost_vars = [var for var in cost_vars if var != '_cons']
cost_vars.append('foreign_{}'.format(base_year))

# Check that all required pair fixed effects exist
needed_fes = (sub_panel['exporter']+'_'+sub_panel['importer']).tolist()
# Set membership keeps this check linear in the number of pairs
known_cost_vars = set(cost_vars)
missing = [fe for fe in needed_fes if fe not in known_cost_vars]
if len(missing)>0:
    raise ValueError('There are country pairs without matching pair fixed effects.')


# ---
# Prep Cost Data
# ---

# Rebuild one indicator column per pair fixed effect; with an empty prefix each column is named after the pair label
fe_dummies = pd.get_dummies(
    fe_ests[['exporter', 'importer', 'var']],
    columns=['var'],
    prefix='',
    prefix_sep='',
)

# Attach the fixed effect indicators to the rest of the input data, recording where each row came from in '_merge'
use_panel = sub_panel.merge(
    fe_dummies,
    how='outer',
    on=['exporter', 'importer'],
    validate='1:1',
    indicator=True,
)

# Every row must have been found on both sides of the merge; anything else means the two panels do not line up
fe_validator = use_panel['_merge'].unique()
if any(ind != 'both' for ind in fe_validator):
    raise ValueError('Fixed effect merge with input data not a perfect fit.')



# ----
# Prep GE inputs
# ----
data_summary = use_panel.describe().transpose()

# Create GME data object
gme_data = gme.EstimationData(use_panel, imp_var_name='importer', exp_var_name='exporter', year_var_name='year',
                              trade_var_name=trade_var)
# Create GME gravity model
# NOTE(review): lhs_var is 'trade', while the data object's trade column is `trade_var` and the original 'trade'
# column was renamed during input prep. No estimation is run in this script, so this is presumably not read --
# confirm against the gegravity/gme versions in use.
gme_model = gme.EstimationModel(gme_data, lhs_var='trade', rhs_var=cost_vars)


# -----
# Create Internet connectivity GE model
# -----

# Define GE model.
# The model year is taken from `base_year` so it stays in step with the year used to select the input data and
# the border fixed effect; it is passed as a string, as before.
connectivity_model = ge.OneSectorGE(gme_model, year=str(base_year), reference_importer='DEU', expend_var_name='expenditure',
                                     output_var_name='output', sigma = 7, cost_variables=cost_vars, cost_coeff_values=cost_object)

# Check for candidate OMR rescale values
# omr_eval = connectivity_model.check_omr_rescale()

# Solve for baseline model
connectivity_model.build_baseline(omr_rescale=1)



# --- Define Experiment ---

# Create copy of the baseline data
connect_exper = connectivity_model.baseline_data.copy()

# Overwrite Nigeria's connectivity values with Brazil's.
# Each assignment reads the Brazil row for the same partner and writes it to the matching Nigeria row; .item()
# raises unless exactly one Brazil row matches.
# NOTE(review): the Brazil-as-importer / Nigeria-as-exporter row is both read (partner 'NGA') and written
# (partner 'BRA') inside this loop, so the result depends on the order of top_countries -- do not reorder.
for var in ['high_int_ex','low_int_ex']:
    for cntry in top_countries:
        # Replace Nigeria as importer
        connect_exper.loc[(connect_exper['importer'] == 'NGA') & (connect_exper['exporter'] == cntry), var] = \
            connect_exper.loc[(connect_exper['importer'] == 'BRA') & (connect_exper['exporter'] == cntry), var].item()
        # Replace Nigeria as exporter
        connect_exper.loc[(connect_exper['exporter'] == 'NGA') & (connect_exper['importer'] == cntry), var] = \
            connect_exper.loc[(connect_exper['exporter'] == 'BRA') & (connect_exper['importer'] == cntry), var].item()
    # Replace Nigeria domestic with Brazil's domestic (overrides what the partner loop wrote for the NGA-NGA pair)
    connect_exper.loc[(connect_exper['exporter'] == 'NGA') & (connect_exper['importer'] == 'NGA'), var] = \
        connect_exper.loc[(connect_exper['exporter'] == 'BRA') & (connect_exper['importer'] == 'BRA'), var].item()

# Supply counterfactual data to the GE model
connectivity_model.define_experiment(connect_exper)

# --- Run Counterfactual ---

connectivity_model.simulate()

# Retrieve country-level results
connectivity_results = connectivity_model.country_results
# Examine the changes in trade costs at country and bilateral level
connectivity_shock = connectivity_model.trade_weighted_shock()
connectivity_shock_bilat = connectivity_model.trade_weighted_shock('bilateral')





# ---
# Format and export main results table 7
# ---

# Country name lookup used to show names instead of ISO codes in the output tables
country_names = pd.read_csv(country_names_local)

# Shorten a handful of compound or formal names to the form used in the paper
name_overrides = {
    'Germany; West Germany': 'Germany',
    'Czech Republic': 'Czechia',
    'Egypt, Arab Rep.': 'Egypt',
    'Korea, South': 'South Korea',
    'Ceylon; Sri Lanka': 'Sri Lanka',
    'Malaya; Malaysia': 'Malaysia',
    'South Vietnam; Vietnam': 'Vietnam',
    "Cambodia; Khmer Republic; Kampuchea": "Cambodia",
}
country_names = country_names.replace(name_overrides)

# Define a function to reformat a table of simulation results
def create_table(country_results, wide=True, path:str = None, round:int = 2):
    """Reformat country-level simulation results into a table of export and GDP changes by country name.

    Reads the module-level ``country_names`` lookup (columns 'iso3', 'country', 'earliest_year', 'latest_year')
    to replace ISO codes with country names.

    Args:
        country_results: DataFrame indexed by ISO code containing the columns
            'foreign exports change (percent)' and 'GDP change (percent)'.
        wide: If True, lay the alphabetized countries out in two side-by-side halves separated by an empty
            column. The split is at the midpoint, so the left half holds the extra row when the number of
            countries is odd. If False, return a single long table.
        path: Optional file path. When given, the table is written with '&' as the separator and a double
            backslash plus newline as the row terminator (LaTeX table rows).
        round: Number of decimal places used for floats when writing to ``path``.

    Returns:
        The formatted DataFrame (wide or long).
    """
    # Select outcomes to report
    results_subset = country_results[['foreign exports change (percent)', 'GDP change (percent)' ]]
    # Replace ISOs with country names (right join keeps every country present in the results)
    table = country_names.merge(results_subset, how ='right', right_index=True, left_on='iso3')
    table = table.drop(['earliest_year', 'latest_year', 'iso3'], axis = 1)
    # Re-alphabetize
    table = table.sort_values('country')
    if wide:
        # Split into two columns of countries instead of 1. The split point follows the table length rather
        # than a fixed row count, so no country is repeated or dropped when the sample size changes.
        left_rows = ceil(len(table) / 2)
        left_half = table.iloc[:left_rows].reset_index(drop=True)
        right_half = table.iloc[left_rows:].reset_index(drop=True)
        # Empty spacer column between the two halves
        spacer = pd.DataFrame([''] * left_rows)
        table_out = pd.concat([left_half, spacer, right_half], axis = 1)
    else:
        table_out = table

    if path:
        # Write to file if path is supplied using LaTeX notation
        table_out.to_csv(path, sep='&', lineterminator='\\\\\n', index=False,
                          float_format = '%.{}f'.format(round))
    return table_out

# Write the Nigeria connectivity results table to the save location
create_table(connectivity_results, path="{}nigeria_results.tex".format(save_path))


# -----
# Post Analysis for Paper discussion
# -----

# Nigeria's GDP change in levels: baseline GDP scaled by the percent change
nga_country = connectivity_model.country_set['NGA']
nga_gdp = nga_country.baseline_gdp * nga_country.gdp_change/100

# Change in the value of each country's foreign exports (experiment minus baseline, observed levels)
connectivity_levels = connectivity_model.calculate_levels()
experiment_exports = connectivity_levels['experiment observed foreign exports']
baseline_exports = connectivity_levels['baseline observed foreign exports']
connectivity_levels['export_change'] = experiment_exports - baseline_exports

# Bilateral trade impacts in percent
connectivity_bilateral = connectivity_model.bilateral_trade_results

# Bilateral trade impacts in levels, ordered from the largest decline to the largest increase
connectivity_b_levels = connectivity_model.calculate_levels('bilateral')
connectivity_b_levels = connectivity_b_levels.sort_values('trade change (observed level)')
is_nga_importer = connectivity_b_levels['importer'] == 'NGA'
is_nga_exporter = connectivity_b_levels['exporter'] == 'NGA'
nigeria_imports = connectivity_b_levels.loc[is_nga_importer, :]
nigeria_exports = connectivity_b_levels.loc[is_nga_exporter, :]

def get_bilateral(results, country):
    """Return a copy of the rows of ``results`` in which ``country`` is either the importer or the exporter."""
    as_importer = results['importer'].eq(country)
    as_exporter = results['exporter'].eq(country)
    involved = as_importer | as_exporter
    return results[involved].copy()

# Third-party trade / diversion: all bilateral flows involving China, India, and Indonesia
chn_con_trade, ind_con_trade, idn_con_trade = (
    get_bilateral(connectivity_b_levels, iso) for iso in ('CHN', 'IND', 'IDN')
)


# ---
# Compute Imports and exports from/to Nigeria by group for table 8
# ---

# One record per modeled country: baseline GDP, counterfactual GDP, and the difference between the two
gdp_records = []
for iso, country in connectivity_model.country_set.items():
    baseline = country.baseline_gdp
    # Counterfactual level implied by the simulated percent change
    counterfactual = baseline * (1+country.gdp_change/100)
    gdp_records.append({'country': iso,
                        'baseline_gdp': baseline,
                        'counter_gdp': counterfactual,
                        'gdp_difference': counterfactual - baseline})

# Stack the country records and order them from the largest loss to the largest gain
gdp_outcomes = pd.DataFrame(gdp_records, columns=['country', 'baseline_gdp', 'counter_gdp', 'gdp_difference'])
gdp_outcomes = gdp_outcomes.sort_values('gdp_difference')


# Define function to compute aggregate GDP effects by country-group
def compute_group_gdp(members, results, label):
    """Aggregate GDP outcomes over a group of countries, always leaving Nigeria out.

    Args:
        members: Collection of country codes that make up the group.
        results: DataFrame with a 'country' column plus 'baseline_gdp' and 'counter_gdp' columns.
        label: Name stored in the 'country' entry of the returned Series.

    Returns:
        Series of column sums over the group, with added '$M gdp change' and '% gdp change' entries and the
        'country' entry set to ``label``.
    """
    in_group = results['country'].isin(members)
    # Nigeria is dropped even when it belongs to the group
    not_nigeria = results['country'] != 'NGA'
    group_results = results.loc[in_group & not_nigeria, :].copy()

    # Column sums for the group; the summed 'country' entry is replaced by the label below
    agg_results = group_results.sum()
    dollar_change = agg_results['counter_gdp'] - agg_results['baseline_gdp']
    agg_results['$M gdp change'] = dollar_change
    agg_results['% gdp change'] = 100* dollar_change/agg_results['baseline_gdp']
    agg_results['country'] = label
    return agg_results

# --- GDP effects by income level
high_income_gdp = compute_group_gdp(high_income_cntry, gdp_outcomes, 'high income')
# Every modeled country that is not in the high/high-middle income list counts as low income
low_income_cntry = [cntry for cntry in top_countries if cntry not in high_income_cntry]
low_income_gdp = compute_group_gdp(low_income_cntry, gdp_outcomes, 'low income')


# --- GDP effects by internet level
# Get a list of countries with above median internet use.
# The data year follows `base_year` so it stays in step with the year used for the rest of the model inputs.
internet_col = 'Individuals using the Internet (% of population) [IT.NET.USER.ZS]'
internet_data = pd.read_csv(internet_local)
internet_data = internet_data.loc[internet_data['Time']==base_year,:]
internet_stats = internet_data[internet_col].describe()
# The median is the '50%' entry of describe(); countries strictly above it are "high internet"
high_internet = internet_data.loc[internet_data[internet_col]>internet_stats['50%'],'Country Code'].tolist()
high_internet.sort()

# Divide model countries into high and low internet use
high_int_model = [cntry for cntry in top_countries if cntry in high_internet]
low_int_model = [cntry for cntry in top_countries if cntry not in high_internet]

# Get GDP outcomes
high_int_gdp = compute_group_gdp(high_int_model, gdp_outcomes, 'high internet')
low_int_gdp = compute_group_gdp(low_int_model, gdp_outcomes, 'low internet')


# --- GDP effects by specific countries or regions
# EU members are taken as the importers appearing in any pair flagged with member_eu_joint == 1
EU = use_panel.loc[(use_panel['member_eu_joint']==1),'importer'].unique()
eu_gdp = compute_group_gdp(EU, gdp_outcomes, 'EU')
usa_gdp = compute_group_gdp(['USA'], gdp_outcomes, 'USA')
chn_gdp = compute_group_gdp(['CHN'], gdp_outcomes, 'CHN')


# One row per group after transposing (each group result is a Series, concatenated as a column)
gdp_combined = pd.concat([high_income_gdp, low_income_gdp, high_int_gdp, low_int_gdp, eu_gdp, usa_gdp, chn_gdp], axis = 1)
gdp_combined = gdp_combined.transpose()



# ----
# Trade with Nigeria by group
# ----
# For each country group, computes the change in trade with Nigeria (percent and $M) on the import and export
# side, joins it with the group GDP effects, and writes the combined table as .tex and .csv.

# Collect bilateral results in changes and levels
bilat_levels = connectivity_model.bilateral_trade_results.reset_index()
bilat_calculated = connectivity_model.calculate_levels(how='bilateral')
bilat_levels = bilat_levels.merge(bilat_calculated[['exporter','importer','baseline observed trade']], how = 'outer', on = ['exporter','importer'])

# (members, label) for every group reported; identical for both sides, so it is defined once outside the loop
group_specs = [(high_income_cntry, 'high income'), (low_income_cntry, 'low income'), (high_int_model, 'high internet'),
               (low_int_model, 'low internet'), (EU,'EU'), (['USA'], 'USA'), (['CHN'],'CHN')]

trade_by_group = dict()
for side in ['importer', 'exporter']:
    if side == 'importer':
        # Group members are the importers, so Nigeria is the exporter
        other_side = 'exporter'
        percent_label = '% change in imports from Nigeria'
        dollar_label = '$M change in imports from Nigeria'
    else:
        # Group members are the exporters, so Nigeria is the importer
        other_side = 'importer'
        percent_label = '% change in exports to Nigeria'
        dollar_label = '$M change in exports to Nigeria'

    # Compute trade changes in  levels and percent by group for each group
    group_ests = list()
    for members, label in group_specs:
        # Rows where a group member is on `side`, the flow is not domestic, and Nigeria is the partner.
        # The selection is copied so the relabeling below writes to an independent frame, not a slice.
        keep = (bilat_levels[side].isin(members)
                & (bilat_levels['importer'] != bilat_levels['exporter'])
                & (bilat_levels[other_side] == 'NGA'))
        group = bilat_levels.loc[keep,:].copy()
        # Replace country codes with group name
        group[side] = label
        # Sum values over group
        group = group.groupby(side).agg({'baseline modeled trade':'sum',
                                         'experiment trade':'sum', 'baseline observed trade':'sum'}).reset_index()
        #  Define percent change using modeled trade values
        group[percent_label] = 100*(group['experiment trade'] - group['baseline modeled trade'])/group['baseline modeled trade']
        # Compute level change using observed levels and estimated/modeled change
        group[dollar_label] = (group[percent_label]/100)*group['baseline observed trade']
        # Relabel column and add to list of group calculations
        group.rename(columns = {side:'country'}, inplace = True)
        group_ests.append(group)
    # Combine computations of all groups
    combined = pd.concat(group_ests, axis = 0)
    trade_by_group[side] = combined

# Merge exporter and importer values
all_trade = trade_by_group['exporter'].merge(trade_by_group['importer'], on = ['country'])

# Combine group trade values with GDP values
trade_gdp_effects = all_trade.merge(gdp_combined, how = 'outer', on = 'country')
tex_trade_gdp = trade_gdp_effects[['country', '% change in imports from Nigeria', '$M change in imports from Nigeria',
                                   '% change in exports to Nigeria', '$M change in exports to Nigeria',
                                   '% gdp change', '$M gdp change']].copy()
# Round values for paper tables: GDP percent to 4 places, dollar values to whole $M, trade percent to 1 place
rounding = {'% gdp change': 4,
            '$M change in imports from Nigeria': 0, '$M change in exports to Nigeria': 0, '$M gdp change': 0,
            '% change in imports from Nigeria': 1, '% change in exports to Nigeria': 1}
for col, digits in rounding.items():
    tex_trade_gdp[col] = tex_trade_gdp[col].astype(float).round(digits)

# Save to csv and as a TeX formatted table
tex_trade_gdp.to_csv("{}Nigeria_trade_gdp_impacts_by_group.tex".format(save_path), sep='&', lineterminator='\\\\\n', index = False)
tex_trade_gdp.to_csv("{}Nigeria_trade_gdp_impacts_by_group.csv".format(save_path), index = False)









